# -*- coding: utf-8 -*- #
"""*********************************************************************************************"""
#   FileName     [ ark2voxceleb.py ]
#   Synopsis     [ process the .ark file preprocessed by kaldi for our dataloader ]
#   Author       [ Andy T. Liu (Andi611) ]
#   Copyright    [ Copyleft(c), Speech Lab, NTU, Taiwan ]
#   Reference    [ https://github.com/nttcslab-sp/kaldiio ]
"""*********************************************************************************************"""

###############
# IMPORTATION #
###############
import os
import sys
import pickle
import operator

import kaldi_io
import numpy as np
import pandas as pd
from tqdm import tqdm


############
# SETTINGS #
############
# Kaldi voxceleb data directory; it must be generated by the kaldi scripts beforehand.
KALDI_PATH = '../kaldi/egs/voxceleb/v1/data/'
# Destination for the per-utterance .npy features (<OUTPUT_DIR>/<set>/) and the <set>.csv manifests.
OUTPUT_DIR = '../data/voxceleb_mfcc_kaldi'


############
# CONSTANT #
############
SETS = ['train']  # can be any subset of: ['train', 'dev', 'test']


########
# MAIN #
########
def main():
    """Convert the Kaldi features of every set in ``SETS`` into .npy files plus a CSV manifest.

    For each set name ``s``:

    * every ``(key, mat)`` pair yielded by ``kaldi_io.read_mat_scp`` for
      ``<KALDI_PATH>/<s>/feats.scp`` is cast to float32 and saved as
      ``<OUTPUT_DIR>/<s>/<key>.npy``;
    * a manifest ``<OUTPUT_DIR>/<s>.csv`` is written with the columns
      ``file_path`` (``<s>/<key>.npy``, relative to ``OUTPUT_DIR``),
      ``length`` (``len()`` of the array, i.e. its first dimension —
      presumably the number of frames) and ``label`` (the string ``'None'``),
      ordered by ``length`` descending.

    Exits with status 1 (after printing a hint to stderr) if ``KALDI_PATH``
    is not an existing directory.
    """
    if not os.path.isdir(KALDI_PATH):
        print('CHANGE THIS TO YOUR OWN KALDI PATH: ', KALDI_PATH, file=sys.stderr)
        print('Please run the kaldi scripts first to generate kaldi data directory.', file=sys.stderr)
        # Non-zero status so shell pipelines notice the failure (bare exit() reported success).
        sys.exit(1)

    # makedirs also creates missing parents (e.g. '../data') and tolerates re-runs.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # read data from the preprocessed kaldi directory
    for s in SETS:
        print('Preprocessing', s, 'data...')
        lengths = {}
        cur_dir = os.path.join(OUTPUT_DIR, s)
        os.makedirs(cur_dir, exist_ok=True)

        scp_path = os.path.join(KALDI_PATH, s, 'feats.scp')  # kaldi data is already sorted
        for key, mat in tqdm(kaldi_io.read_mat_scp(scp_path)):  # (key, mat) is returned
            # asarray with a dtype skips the copy when the matrix is already float32.
            array = np.asarray(mat, dtype=np.float32)
            # Build the file name explicitly: np.save only appends '.npy' when the
            # name lacks it, so this keeps the saved file and manifest entry in sync.
            file_name = key + '.npy'
            np.save(os.path.join(cur_dir, file_name), array)
            lengths[os.path.join(s, file_name)] = len(array)

        # Longest utterances first.
        ordered = sorted(lengths.items(), key=operator.itemgetter(1), reverse=True)
        df = pd.DataFrame(data={
            'file_path': [fp for fp, _ in ordered],
            'length': [n for _, n in ordered],
            'label': 'None',
        })
        df.to_csv(os.path.join(OUTPUT_DIR, s + '.csv'))

    print('[ARK-TO-VOXCELEB] - All done, saved at \'' + str(OUTPUT_DIR) + '\', exit.')

if __name__ == '__main__':
    # Run the conversion only when executed as a script, not on import.
    main()